{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd1cdefe-b8c1-40b9-be31-006d52ec9f18",
   "metadata": {
    "tags": [
     "hide"
    ]
   },
   "outputs": [],
   "source": [
    "import seaborn.objects as so\n",
    "from seaborn import load_dataset\n",
    "\n",
    "# Reshape the GLUE results into a wide table: one row per (Model, Encoder)\n",
    "# pair and one column per Task, with Score as the cell values.\n",
    "glue = (\n",
    "    load_dataset(\"glue\")\n",
    "    .pivot(index=[\"Model\", \"Encoder\"], columns=\"Task\", values=\"Score\")\n",
    "    # Row-wise mean across the task columns, rounded to one decimal place.\n",
    "    .assign(Average=lambda scores: scores.mean(axis=\"columns\").round(1))\n",
    "    # Highest average first.\n",
    "    .sort_values(by=\"Average\", ascending=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "3e49ffb1-8778-4cd5-80d6-9d7e1438bc9c",
   "metadata": {},
   "source": [
    "Add text at x/y locations on the plot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bf21068-d39e-436c-8deb-aa1b15aeb2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single Text layer: the `text` variable (Model) is drawn at each x/y position.\n",
    "so.Plot(glue, x=\"SST-2\", y=\"MRPC\", text=\"Model\").add(so.Text())"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a4b9a8b2-6603-46db-9ede-3b3fb45e0e64",
   "metadata": {},
   "source": [
    "Add bar annotations, horizontally aligned with `halign`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f68501f0-c868-439e-9485-d71cca86ea47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two layers: the bars, then each bar's Average drawn as white, right-aligned text.\n",
    "(\n",
    "    so.Plot(\n",
    "        glue,\n",
    "        x=\"Average\",\n",
    "        y=\"Model\",\n",
    "        text=\"Average\",\n",
    "    )\n",
    "    .add(so.Bar())\n",
    "    .add(so.Text(color=\"w\", halign=\"right\"))\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a9d39479-0afa-477b-8403-fe92a54643c9",
   "metadata": {},
   "source": [
    "Fine-tune the alignment using `offset`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5da4a9d-79f3-4c11-bab3-f89da8512ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar plus right-aligned text layers, with an explicit `offset` on the Text mark\n",
    "# to adjust where the aligned labels sit.\n",
    "(\n",
    "    so.Plot(data=glue, x=\"Average\", y=\"Model\", text=\"Average\")\n",
    "    .add(so.Bar())\n",
    "    .add(so.Text(halign=\"right\", offset=6, color=\"w\"))\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e9c43798-70d5-42b5-bd91-b85684d1b671",
   "metadata": {},
   "source": [
    "Add text above dots, mapping the text color to a third variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d26ebc-24ac-4531-9ba2-fa03720c58bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `color` is declared on the Plot, so it is shared by the Dot and Text layers.\n",
    "# The bottom vertical alignment places each label above its dot.\n",
    "(\n",
    "    so.Plot(glue, x=\"SST-2\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n",
    "    .add(so.Dot())\n",
    "    .add(so.Text(valign=\"bottom\"))\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f31aaa38-6728-4299-8422-8762c52c9857",
   "metadata": {},
   "source": [
    "Map the text alignment to a variable for better use of space:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf4bbf0c-0c5f-4c31-b971-720ea8910918",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `halign` is mapped from Encoder for the Text layer only; the scale then pins\n",
    "# each encoder value to an explicit alignment.\n",
    "(\n",
    "    so.Plot(glue, x=\"RTE\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n",
    "    .add(so.Dot())\n",
    "    .add(so.Text(), halign=\"Encoder\")\n",
    "    .scale(\n",
    "        halign={\n",
    "            \"LSTM\": \"left\",\n",
    "            \"Transformer\": \"right\",\n",
    "        }\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a5de35a6-1ccf-4958-8013-edd9ed1cd4b0",
   "metadata": {},
   "source": [
    "Use additional matplotlib parameters to control the appearance of the text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4be188-1614-4c19-9bd7-b07e986f6a23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additional matplotlib text parameters are passed through `artist_kws`;\n",
    "# naming the keyword makes it clear which Text parameter receives the dict.\n",
    "(\n",
    "    so.Plot(glue, x=\"RTE\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n",
    "    .add(so.Dot())\n",
    "    .add(so.Text(artist_kws={\"fontweight\": \"bold\"}), halign=\"Encoder\")\n",
    "    .scale(halign={\"LSTM\": \"left\", \"Transformer\": \"right\"})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95fb7aee-090a-4415-917c-b5258d2b298b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "py310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
